import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import activations
from tensorflow.keras.layers import Layer, Input, Embedding, LSTM, Dense, Attention
from tensorflow.keras.models import Model
import numpy as np



class Encoder(keras.Model):
    """Encoder half of the seq2seq model: an embedding followed by an LSTM."""

    def __init__(self, vocab_size, embedding_dim, hidden_units):
        super().__init__()
        # Maps token ids to dense vectors; mask_zero=True treats id 0 as padding.
        self.embedding = Embedding(input_dim=vocab_size,
                                   output_dim=embedding_dim,
                                   mask_zero=True)
        # Emits the full output sequence as well as the final hidden/cell states.
        self.encoder_lstm = LSTM(units=hidden_units,
                                 return_sequences=True,
                                 return_state=True,
                                 name="encode_lstm")

    def call(self, inputs):
        """Embed `inputs` and run the LSTM over them.

        Returns:
            A tuple (output sequence, final hidden state, final cell state).
        """
        embedded = self.embedding(inputs)
        sequence, final_h, final_c = self.encoder_lstm(embedded)
        return sequence, final_h, final_c


class Decoder(keras.Model):
    """Decoder half of the seq2seq model: embedding, LSTM, then attention."""

    def __init__(self, vocab_size, embedding_dim, hidden_units):
        super().__init__()
        # Maps token ids to dense vectors; mask_zero=True treats id 0 as padding.
        self.embedding = Embedding(input_dim=vocab_size,
                                   output_dim=embedding_dim,
                                   mask_zero=True)
        # Emits the full output sequence as well as the final hidden/cell states.
        self.decoder_lstm = LSTM(units=hidden_units,
                                 return_sequences=True,
                                 return_state=True,
                                 name="decode_lstm")
        # Attention layer; called with [decoder outputs, encoder outputs].
        self.attention = Attention()

    def call(self, enc_outputs, dec_inputs, states_inputs):
        """Decode `dec_inputs`, starting the LSTM from `states_inputs`.

        Args:
            enc_outputs: Encoder output sequence that the attention layer
                attends over.
            dec_inputs: Decoder token ids.
            states_inputs: Initial LSTM state passed as `initial_state`
                (the callers in this file pass `[state_h, state_c]`).

        Returns:
            A tuple (attention output, final hidden state, final cell state).
        """
        embedded = self.embedding(dec_inputs)
        lstm_result = self.decoder_lstm(embedded, initial_state=states_inputs)
        sequence, final_h, final_c = lstm_result
        # Decoder outputs act as the query, encoder outputs as the value.
        context = self.attention([sequence, enc_outputs])

        return context, final_h, final_c

def Seq2Seq(maxlen, embedding_dim, hidden_units, vocab_size):
    """Assemble the full seq2seq model with attention.

    Args:
        maxlen: Fixed length of the encoder input sequence.
        embedding_dim: Size of the embedding vectors in encoder and decoder.
        hidden_units: Number of units in the encoder and decoder LSTMs.
        vocab_size: Vocabulary size; also the width of the softmax output.

    Returns:
        A Keras `Model` taking `[encode_input, decode_input]` and producing
        the softmax output of the final dense layer.
    """
    # The encoder input has a fixed length; the decoder input length is free.
    source_ids = Input(shape=(maxlen,), name="encode_input")
    target_ids = Input(shape=(None,), name="decode_input")

    # Encoder: its final states seed the decoder LSTM.
    encoder = Encoder(vocab_size, embedding_dim, hidden_units)
    enc_sequence, enc_h, enc_c = encoder(source_ids)

    # Decoder: only the attention output is used here, the states are dropped.
    decoder = Decoder(vocab_size, embedding_dim, hidden_units)
    attended, _, _ = decoder(enc_sequence, target_ids, [enc_h, enc_c])

    # Project the attention output onto the vocabulary.
    projection = Dense(vocab_size, activation='softmax', name="dense")
    token_probs = projection(attended)

    return Model(inputs=[source_ids, target_ids], outputs=token_probs)

def read_vocab(vocab_path):
    """Read a vocabulary file and return its entries in file order.

    Each line becomes one entry with surrounding whitespace removed. Blank
    lines are kept as empty strings, so list positions match line numbers.

    Args:
        vocab_path: Path to a UTF-8 text file with one word per line.

    Returns:
        A list of stripped lines.
    """
    with open(vocab_path, "r", encoding="utf8") as vocab_file:
        return [entry.strip() for entry in vocab_file]

# Model hyper-parameters passed to Seq2Seq below.
maxlen = 10
embedding_dim = 50
hidden_units = 128

# Vocabulary: the four special tokens take ids 0-3, followed by the words read
# from the vocabulary file in file order.
vocab_words = read_vocab("data/ch_word_vocab.txt")
special_words = ["<PAD>", "<UNK>", "<GO>", "<EOS>"]
vocab_words = special_words + vocab_words
vocab2id = {word: i for i, word in enumerate(vocab_words)}
id2vocab = {i: word for i, word in enumerate(vocab_words)}
vocab_size = len(vocab2id)

# Rebuild the network and load the saved weights into it.
model = Seq2Seq(maxlen, embedding_dim, hidden_units, vocab_size)
model.load_weights("data/seq2seq_attention_weights_1000.h5")
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() would only add a stray "None" line.
model.summary()


def encoder_infer(model):
    """Build a stand-alone encoder model from the full seq2seq model.

    Args:
        model: Model containing a layer named 'encoder'.

    Returns:
        A Keras `Model` that maps that layer's input (node 0) to its
        outputs (node 0).
    """
    encoder_layer = model.get_layer('encoder')
    return Model(inputs=encoder_layer.get_input_at(0),
                 outputs=encoder_layer.get_output_at(0))

encoder_model = encoder_infer(model)
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() would only add a stray "None" line.
encoder_model.summary()


def decoder_infer(model, encoder_model):
    """Build the decoder model used for step-by-step inference.

    The returned model reuses the 'decoder' and 'dense' layers of `model`,
    but takes the encoder outputs and the LSTM states as explicit inputs so
    the states can be fed back in between decoding steps.

    Args:
        model: Full seq2seq model containing the layers 'decode_input',
            'decoder' and 'dense'.
        encoder_model: Inference encoder containing the layer 'encoder'; only
            the static shape of its first output is read.

    Returns:
        A Keras `Model` whose inputs are the flat list
        `[enc_output, decode_input, input_state_h, input_state_c]` and whose
        outputs are `[dense output, state_h, state_c]`.
    """
    # The first encoder output is the output sequence; its static shape gives
    # the sizes for the new placeholder input.
    encoder_output = encoder_model.get_layer('encoder').output[0]
    maxlen, hidden_units = encoder_output.shape[1:]

    # Reuse the existing decoder token input; add inputs for the encoder
    # outputs and for both LSTM states.
    dec_input = model.get_layer('decode_input').input
    enc_output = Input(shape=(maxlen, hidden_units), name='enc_output')
    dec_input_state_h = Input(shape=(hidden_units,), name='input_state_h')
    dec_input_state_c = Input(shape=(hidden_units,), name='input_state_c')
    dec_input_states = [dec_input_state_h, dec_input_state_c]

    decoder = model.get_layer('decoder')
    dec_outputs, out_state_h, out_state_c = decoder(enc_output, dec_input, dec_input_states)
    dec_output_states = [out_state_h, out_state_c]

    decoder_dense = model.get_layer('dense')
    dense_output = decoder_dense(dec_outputs)

    # Declare the inputs as one flat list: that is the structure the caller
    # passes to predict(), so declared and supplied inputs line up one-to-one.
    decoder_model = Model(inputs=[enc_output, dec_input] + dec_input_states,
                          outputs=[dense_output] + dec_output_states)
    return decoder_model


decoder_model = decoder_infer(model, encoder_model)
# Model.summary() prints the table itself and returns None, so wrapping it in
# print() would only add a stray "None" line.
decoder_model.summary()






def infer_predict(input_text, encoder_model, decoder_model):
    """Greedily decode a reply for one whitespace-tokenised sentence.

    Reads the module-level `maxlen`, `vocab2id` and `id2vocab`.

    Args:
        input_text: Sentence whose words are separated by whitespace.
        encoder_model: Model whose `predict` returns
            (encoder outputs, state_h, state_c) for a batch of id sequences.
        decoder_model: Model whose `predict` takes
            [encoder outputs, decoder token ids, state_h, state_c] and returns
            (token scores, state_h, state_c).

    Returns:
        A tuple (result_id, result_text) with the predicted ids and the
        matching words. Decoding stops after "<EOS>" is produced (it is
        included in the result) or after `maxlen` steps.
    """
    pad_id = vocab2id["<PAD>"]
    unk_id = vocab2id["<UNK>"]
    go_id = vocab2id["<GO>"]
    eos_id = vocab2id["<EOS>"]

    # Keep two slots free for <GO> and <EOS> so the id sequence never grows
    # past maxlen, the length that the padding below targets.
    text_words = input_text.split()[:max(maxlen - 2, 0)]
    input_id = [go_id] + [vocab2id.get(w, unk_id) for w in text_words] + [eos_id]
    # Right-pad to maxlen; a non-positive count adds nothing.
    input_id = input_id + [pad_id] * (maxlen - len(input_id))

    input_source = np.array([input_id])
    # Shape (1, 1): one batch item, one time step - the same shape as the
    # inputs fed on every later step.
    dec_inputs = np.array([[go_id]])

    # Run the encoder once; its final states seed the decoder.
    enc_outputs, enc_state_h, enc_state_c = encoder_model.predict([input_source])
    dec_states_inputs = [enc_state_h, enc_state_c]

    result_id = []
    result_text = []
    for _ in range(maxlen):
        # Run the decoder for one step.
        dense_outputs, dec_state_h, dec_state_c = decoder_model.predict([enc_outputs, dec_inputs] + dec_states_inputs)
        # Greedy choice: highest-scoring token at the first (only) time step.
        pred_id = np.argmax(dense_outputs[0][0])
        result_id.append(pred_id)
        result_text.append(id2vocab[pred_id])
        if id2vocab[pred_id] == "<EOS>":
            break
        # Feed the prediction and the updated states into the next step.
        dec_inputs = np.array([[pred_id]])
        dec_states_inputs = [dec_state_h, dec_state_c]
    return result_id, result_text

def predict_diag(input_text):
    """Return the chatbot's reply to `input_text` as a single string.

    Uses the module-level `encoder_model` and `decoder_model`.

    Args:
        input_text: Sentence whose words are separated by whitespace.

    Returns:
        The predicted words joined without separators, or "[无话可说]" when
        no word other than the end marker was produced.
    """
    _, result_text = infer_predict(input_text, encoder_model, decoder_model)
    # Strip the end marker only if it is really there: decoding can also stop
    # at the length limit, and then the last token is a real word to keep.
    if result_text and result_text[-1] == "<EOS>":
        result_text = result_text[:-1]
    if not result_text:
        return "[无话可说]"
    return "".join(result_text)

if __name__ == "__main__":
    # Run a few sample sentences through the chatbot and print each reply.
    for sample in ("你 好", "听 不懂 你 说 啥", "能 和 你 一起 看 电影 吗"):
        print("Input: ", sample)
        print("Output: ", predict_diag(sample))